Skip to content

Commit c7deafd

Browse files
committed
Added tests for broker ctx manager.
1 parent 323d09d commit c7deafd

2 files changed

Lines changed: 108 additions & 1 deletion

File tree

taskiq/abc/broker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ def _register_task(
535535
self.local_task_registry[task_name] = task
536536

537537
async def __aenter__(self) -> None:
538-
"""Satarts the broker as ctx manager."""
538+
"""Starts the broker as ctx manager."""
539539
await self.startup()
540540

541541
async def __aexit__(self, *args: object, **kwargs: Any) -> None:

tests/abc/test_broker.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from collections.abc import AsyncGenerator
22
from copy import copy
33

4+
import pytest
5+
46
from taskiq.abc.broker import AsyncBroker
57
from taskiq.decor import AsyncTaskiqDecoratedTask
8+
from taskiq.events import TaskiqEvents
69
from taskiq.message import BrokerMessage
10+
from taskiq.state import TaskiqState
711

812

913
class _TestBroker(AsyncBroker):
@@ -76,3 +80,106 @@ async def test_task() -> None: ...
7680
assert "another_label" in test_kicker.labels
7781

7882
assert test_task.labels == old_labels
83+
84+
85+
@pytest.mark.anyio
86+
async def test_async_context_manager_enter() -> None:
87+
"""Test that __aenter__ calls startup."""
88+
broker = _TestBroker()
89+
startup_called = False
90+
91+
@broker.on_event(TaskiqEvents.CLIENT_STARTUP)
92+
async def track_startup(state: TaskiqState) -> None:
93+
nonlocal startup_called
94+
startup_called = True
95+
96+
async with broker:
97+
assert startup_called is True
98+
99+
100+
@pytest.mark.anyio
101+
async def test_async_context_manager_exit() -> None:
102+
"""Test that __aexit__ calls shutdown."""
103+
broker = _TestBroker()
104+
shutdown_called = False
105+
106+
@broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN)
107+
async def track_shutdown(state: TaskiqState) -> None:
108+
nonlocal shutdown_called
109+
shutdown_called = True
110+
111+
async with broker:
112+
pass
113+
114+
assert shutdown_called is True
115+
116+
117+
@pytest.mark.anyio
118+
async def test_async_context_manager_enter_worker() -> None:
119+
"""Test that __aenter__ calls worker startup when is_worker_process is True."""
120+
broker = _TestBroker()
121+
broker.is_worker_process = True
122+
startup_called = False
123+
124+
@broker.on_event(TaskiqEvents.WORKER_STARTUP)
125+
async def track_startup(state: TaskiqState) -> None:
126+
nonlocal startup_called
127+
startup_called = True
128+
129+
async with broker:
130+
assert startup_called is True
131+
132+
133+
@pytest.mark.anyio
134+
async def test_async_context_manager_exit_worker() -> None:
135+
"""Test that __aexit__ calls worker shutdown when is_worker_process is True."""
136+
broker = _TestBroker()
137+
broker.is_worker_process = True
138+
shutdown_called = False
139+
140+
@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)
141+
async def track_shutdown(state: TaskiqState) -> None:
142+
nonlocal shutdown_called
143+
shutdown_called = True
144+
145+
async with broker:
146+
pass
147+
148+
assert shutdown_called is True
149+
150+
151+
@pytest.mark.anyio
152+
async def test_async_context_manager_exit_on_exception() -> None:
153+
"""Test that __aexit__ calls shutdown even if exception is raised."""
154+
broker = _TestBroker()
155+
shutdown_called = False
156+
157+
@broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN)
158+
async def track_shutdown(state: TaskiqState) -> None:
159+
nonlocal shutdown_called
160+
shutdown_called = True
161+
162+
with pytest.raises(ValueError, match="Test exception"):
163+
async with broker:
164+
raise ValueError("Test exception")
165+
166+
assert shutdown_called is True
167+
168+
169+
@pytest.mark.anyio
170+
async def test_async_context_manager_exit_worker_on_exception() -> None:
171+
"""Test that __aexit__ calls worker shutdown even if exception is raised."""
172+
broker = _TestBroker()
173+
broker.is_worker_process = True
174+
shutdown_called = False
175+
176+
@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)
177+
async def track_shutdown(state: TaskiqState) -> None:
178+
nonlocal shutdown_called
179+
shutdown_called = True
180+
181+
with pytest.raises(ValueError, match="Test exception"):
182+
async with broker:
183+
raise ValueError("Test exception")
184+
185+
assert shutdown_called is True

0 commit comments

Comments
 (0)