Skip to content

Commit b81eb06

Browse files
committed
feat: support cancel of OpFuture
1 parent e7fefa4 commit b81eb06

4 files changed

Lines changed: 86 additions & 62 deletions

File tree

src/duron/event_loop.py

Lines changed: 29 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import contextvars
66
import heapq
77
import logging
8-
import threading
98
from asyncio import AbstractEventLoop, Handle, Task, TimerHandle, events
109
from collections import deque
1110
from dataclasses import dataclass
@@ -86,7 +85,6 @@ def __init__(self, seed: bytes) -> None:
8685
self._closed: bool = False
8786
self._event: asyncio.Event = asyncio.Event()
8887

89-
self._lock: threading.Lock = threading.Lock()
9088
self._timers: list[TimerHandle] = []
9189

9290
def _generate_id(self) -> bytes:
@@ -116,19 +114,9 @@ def call_soon_threadsafe(
116114
callback: Callable[[Unpack[_Ts]], object],
117115
*args: Unpack[_Ts],
118116
context: Context | None = None,
119-
task_id: bytes | None = None,
120117
) -> Handle:
121-
self._event.set()
122-
h = TimerHandle(
123-
0,
124-
callback,
125-
args,
126-
self,
127-
context=self._context_with_task_id(context, task_id=task_id),
128-
)
129-
with self._lock:
130-
heapq.heappush(self._timers, h)
131-
return h
118+
# the loop is not thread-safe, so we just call call_soon
119+
return self.call_soon(callback, *args, context=context)
132120

133121
@override
134122
def call_at(
@@ -143,10 +131,9 @@ def call_at(
143131
callback,
144132
args,
145133
loop=self,
146-
context=self._context_with_task_id(context),
134+
context=context,
147135
)
148-
with self._lock:
149-
heapq.heappush(self._timers, th)
136+
heapq.heappush(self._timers, th)
150137
return th
151138

152139
@override
@@ -179,7 +166,7 @@ def create_task(
179166
context: Context | None = None,
180167
**kwargs: Any,
181168
) -> Task[_T]:
182-
ctx = self._context_with_task_id(context)
169+
ctx = self._context_new_task(context, self._generate_id())
183170
return ctx.run(
184171
cast("type[Task[_T]]", Task), coro, name=name, loop=self, **kwargs
185172
)
@@ -193,17 +180,16 @@ def poll_completion(self, task: Future[_T]) -> WaitSet | None:
193180
deadline: float | None = None
194181
while True:
195182
deadline = None
196-
with self._lock:
197-
while self._timers:
198-
ht = self._timers[0]
199-
if ht.cancelled():
200-
_ = heapq.heappop(self._timers)
201-
elif ht.when() <= now:
202-
_ = heapq.heappop(self._timers)
203-
self._ready.append(ht)
204-
else:
205-
deadline = ht.when()
206-
break
183+
while self._timers:
184+
ht = self._timers[0]
185+
if ht.cancelled():
186+
_ = heapq.heappop(self._timers)
187+
elif ht.when() <= now:
188+
_ = heapq.heappop(self._timers)
189+
self._ready.append(ht)
190+
else:
191+
deadline = ht.when()
192+
break
207193

208194
if not self._ready:
209195
break
@@ -213,7 +199,7 @@ def poll_completion(self, task: Future[_T]) -> WaitSet | None:
213199
continue
214200
try:
215201
h._run()
216-
except BaseException as exc:
202+
except Exception as exc:
217203
self.call_exception_handler({
218204
"message": "exception in callback",
219205
"exception": exc,
@@ -237,20 +223,20 @@ def create_op(self, params: object) -> OpFuture:
237223
return s
238224

239225
@overload
240-
def post_completion_threadsafe(
226+
def post_completion(
241227
self,
242228
id: bytes,
243229
*,
244230
result: object,
245231
) -> None: ...
246232
@overload
247-
def post_completion_threadsafe(
233+
def post_completion(
248234
self,
249235
id: bytes,
250236
*,
251237
exception: BaseException,
252238
) -> None: ...
253-
def post_completion_threadsafe(
239+
def post_completion(
254240
self,
255241
id: bytes,
256242
*,
@@ -259,10 +245,15 @@ def post_completion_threadsafe(
259245
) -> None:
260246
if op := self._ops.pop(id, None):
261247
tid = _mix_id(op.id, -1)
262-
if exception is not None:
263-
_ = self.call_soon_threadsafe(op.set_exception, exception, task_id=tid)
264-
else:
265-
_ = self.call_soon_threadsafe(op.set_result, result, task_id=tid)
248+
self._event.set()
249+
self._ready.append(
250+
Handle(
251+
op.set_result if exception is None else op.set_exception,
252+
[(result if exception is None else exception)],
253+
self,
254+
context=self._context_new_task(None, tid),
255+
)
256+
)
266257

267258
@override
268259
def is_closed(self) -> bool:
@@ -319,12 +310,8 @@ async def shutdown_default_executor(self):
319310
def _timer_handle_cancelled(self, _th: TimerHandle) -> None:
320311
pass
321312

322-
def _context_with_task_id(
323-
self, context: Context | None, task_id: bytes | None = None
324-
) -> Context:
313+
def _context_new_task(self, context: Context | None, task_id: bytes) -> Context:
325314
base = context.copy() if context is not None else contextvars.copy_context()
326-
if task_id is None:
327-
task_id = self._generate_id()
328315
_ = base.run(_task_ctx.set, _TaskCtx(parent_id=task_id))
329316
return base
330317

src/duron/task.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from types import TracebackType
2828

2929
from duron.codec import Codec
30-
from duron.event_loop import WaitSet
30+
from duron.event_loop import OpFuture, WaitSet
3131
from duron.fn import DurableFn
3232
from duron.log import (
3333
Entry,
@@ -36,7 +36,6 @@
3636
LogStorage,
3737
PromiseCompleteEntry,
3838
)
39-
from duron.ops import Op
4039

4140

4241
_T = TypeVar("_T")
@@ -138,16 +137,17 @@ def __init__(
138137
self._codec = codec
139138
self._running: object | None = None
140139
self._pending_msg: list[Entry] = []
141-
self._pending_task: dict[str, Callable[[], Coroutine[Any, Any, object]]] = {}
140+
self._pending_task: dict[str, Callable[[], Coroutine[Any, Any, None]]] = {}
142141
self._pending_ops: set[bytes] = set()
143142
self._now = 0
144143
self._offset: object | None = None
144+
self._tasks: dict[str, asyncio.Task[None]] = {}
145145

146146
def now(self) -> int:
147147
if self._running:
148148
t = time.time_ns()
149149
t -= t % 1_000
150-
self._now = max(self._now, t)
150+
self._now = max(self._now + 1_000, t)
151151
return self._now
152152

153153
async def resume(self) -> None:
@@ -177,8 +177,8 @@ async def run(self) -> object:
177177
for msg in self._pending_msg:
178178
await self.enqueue_log(msg)
179179
self._pending_msg.clear()
180-
for task in self._pending_task.values():
181-
_ = asyncio.create_task(task())
180+
for key, task in self._pending_task.items():
181+
self._tasks[key] = asyncio.create_task(task())
182182
self._pending_task.clear()
183183

184184
t1 = asyncio.create_task(self._follow_log())
@@ -193,6 +193,11 @@ async def run(self) -> object:
193193
with contextlib.suppress(asyncio.CancelledError):
194194
await t
195195

196+
for t in self._tasks.values():
197+
_ = t.cancel()
198+
with contextlib.suppress(asyncio.CancelledError):
199+
await t
200+
196201
for t in done:
197202
if exc := t.exception():
198203
raise exc
@@ -223,20 +228,20 @@ async def _step(self) -> WaitSet | None:
223228
sid = s.id
224229
if sid not in self._pending_ops:
225230
self._pending_ops.add(sid)
226-
await self.enqueue_op(sid, s.params)
231+
await self.enqueue_op(sid, s)
227232

228233
async def handle_message(self, e: Entry) -> None:
229234
if e["type"] == "promise/complete":
230235
_ = self._pending_task.pop(e["promise_id"], None)
231236
id = _decode_id(e["promise_id"])
232237
if "error" in e:
233-
self._loop.post_completion_threadsafe(
238+
self._loop.post_completion(
234239
id,
235240
exception=_decode_error(e["error"]),
236241
)
237242
self._pending_ops.discard(id)
238243
elif "result" in e:
239-
self._loop.post_completion_threadsafe(
244+
self._loop.post_completion(
240245
id,
241246
result=self._codec.decode_json(e["result"]),
242247
)
@@ -251,11 +256,12 @@ async def enqueue_log(self, entry: Entry, flush: bool = False) -> None:
251256
self._pending_msg.append(entry)
252257
else:
253258
await self._log.append(self._running, entry)
254-
await self.handle_message(entry)
255259
if flush:
256260
await self._log.flush(self._running)
261+
await self.handle_message(entry)
257262

258-
async def enqueue_op(self, id: bytes, op: Op | object) -> None:
263+
async def enqueue_op(self, id: bytes, fut: OpFuture) -> None:
264+
op = fut.params
259265
match op:
260266
case FnCall():
261267
await self.enqueue_log({
@@ -278,12 +284,22 @@ async def cb() -> None:
278284
entry["result"] = self._codec.encode_json(result)
279285
except BaseException as e:
280286
entry["error"] = _encode_error(e)
281-
await self.enqueue_log(entry)
282-
287+
finally:
288+
await self.enqueue_log(entry)
289+
290+
def done(f: OpFuture) -> None:
291+
if f.cancelled():
292+
sid = _encode_id(f.id)
293+
_ = self._pending_task.pop(sid, None)
294+
if tsk := self._tasks.get(sid, None):
295+
_ = tsk.cancel()
296+
297+
fut.add_done_callback(done)
298+
sid = _encode_id(id)
283299
if self._running:
284-
_ = asyncio.create_task(cb())
300+
self._tasks[sid] = asyncio.create_task(cb())
285301
else:
286-
self._pending_task[_encode_id(id)] = cb
302+
self._pending_task[sid] = cb
287303

288304
case TaskRun():
289305
await self.enqueue_log({
@@ -304,8 +320,8 @@ async def cb() -> None:
304320
entry["result"] = self._codec.encode_json(result)
305321
except BaseException as e:
306322
entry["error"] = _encode_error(e)
307-
308-
await self.enqueue_log(entry, True)
323+
finally:
324+
await self.enqueue_log(entry, True)
309325

310326
_ = self._loop.create_task(cb())
311327

@@ -336,7 +352,7 @@ def _decode_timestamp(ts: int) -> int:
336352
def _encode_error(error: BaseException) -> ErrorInfo:
337353
return {
338354
"code": -1,
339-
"message": str(error),
355+
"message": repr(error),
340356
}
341357

342358

tests/test_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def tick(n: int, expect: set[int] | None) -> list[bytes] | None:
6868

6969
for i in range(5, 0, -1):
7070
ws = tick(i, {1, 2, 3, 4, 5})
71-
loop.post_completion_threadsafe(ws[random.randint(0, i - 1)], result=i - 1)
71+
loop.post_completion(ws[random.randint(0, i - 1)], result=i - 1)
7272
ws = tick(1, {6})
73-
loop.post_completion_threadsafe(ws[0], result=6)
73+
loop.post_completion(ws[0], result=6)
7474
tick(0, None)
7575

7676
return ids

tests/test_task.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextlib
45
import random
56
import uuid
67
from dataclasses import dataclass
@@ -108,6 +109,26 @@ async def activity(ctx: Context, s: str) -> str:
108109
assert x == "hello"
109110

110111

112+
@pytest.mark.asyncio
113+
async def test_cancel():
114+
@durable()
115+
async def activity(ctx: Context, s: str) -> str:
116+
with contextlib.suppress(BaseException):
117+
_ = await asyncio.wait_for(ctx.run(lambda: asyncio.sleep(9999)), 0.1)
118+
_ = await asyncio.wait_for(ctx.run(lambda: asyncio.sleep(9999)), 0.1)
119+
return s
120+
121+
log = MemoryLogStorage()
122+
async with activity(log) as t:
123+
await t.start("hello")
124+
try:
125+
_ = await t.wait()
126+
except Exception as e:
127+
assert "Timeout" in repr(e)
128+
129+
assert len(await log.entries()) == 8
130+
131+
111132
@dataclass
112133
class CustomPoint:
113134
x: int

0 commit comments

Comments
 (0)