|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import asyncio |
| 4 | +from typing import TYPE_CHECKING |
| 5 | + |
| 6 | +if TYPE_CHECKING: |
| 7 | + from typing import Dict |
| 8 | + |
| 9 | + |
| 10 | +class TaskCancellationManager: |
| 11 | + def __init__(self) -> None: |
| 12 | + self.cancel_events: Dict[str, asyncio.Event] = {} |
| 13 | + self.background_tasks: Dict[str, asyncio.Task] = {} |
| 14 | + |
| 15 | + def register(self, task_id: str, bg_task: asyncio.Task) -> None: |
| 16 | + cancel_event = asyncio.Event() |
| 17 | + self.cancel_events[task_id] = cancel_event |
| 18 | + self.background_tasks[task_id] = bg_task |
| 19 | + |
| 20 | + def signal_cancel(self, task_id: str) -> None: |
| 21 | + if task_id in self.cancel_events: |
| 22 | + self.cancel_events[task_id].set() |
| 23 | + if task_id in self.background_tasks: |
| 24 | + self.background_tasks[task_id].cancel() |
| 25 | + |
| 26 | + def is_canceled(self, task_id: str) -> bool: |
| 27 | + event = self.cancel_events.get(task_id) |
| 28 | + return event is not None and event.is_set() |
| 29 | + |
| 30 | + def cleanup(self, task_id: str) -> None: |
| 31 | + self.cancel_events.pop(task_id, None) |
| 32 | + self.background_tasks.pop(task_id, None) |
0 commit comments