Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ lint = [
"ruff"
]
test = [
"pytest"
"pytest",
"pytest-timeout",
"anyio",
]

[tool.maturin]
Expand All @@ -61,6 +63,13 @@ homepage = "https://github.com/psqlpy-python/psqlpy"
repository = "https://github.com/psqlpy-python/psqlpy"
documentation = "https://psqlpy-python.github.io/"

[tool.pytest.ini_options]
# Safety net so a wedged test (e.g. a listener waiting on a notification that
# never arrives) fails fast instead of hanging the whole CI job. The listener
# tests bound their own waits well below this; this is the backstop for anything
# unforeseen.
timeout = 120

[tool.mypy]
strict = true
mypy_path = "python"
Expand Down
152 changes: 100 additions & 52 deletions python/tests/test_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import typing

import anyio
import pytest
from psqlpy.exceptions import ListenerStartError

Expand All @@ -15,6 +16,15 @@
TEST_CHANNEL = "test_channel"
TEST_PAYLOAD = "test_payload"

# How long helpers wait for an asynchronous condition before giving up.
# These bound every wait in this module so a lost/late NOTIFY surfaces as a
# fast, explicit failure instead of hanging the whole test session forever
# (which is what used to wedge the GitHub runners).
WAIT_TIMEOUT = 10.0
POLL_INTERVAL = 0.05
# Time we allow a notification to be (not) delivered before asserting absence.
SETTLE_TIMEOUT = 1.0


async def construct_listener(
psql_pool: ConnectionPool,
Expand Down Expand Up @@ -54,44 +64,79 @@ async def callback(
return callback


async def wait_until_listening(
listener: Listener,
*channels: str,
) -> None:
"""Block until the listener's backend session is subscribed to ``channels``.

``Listener.listen()`` and async iteration issue the ``LISTEN`` statements
lazily from a background task, so a ``NOTIFY`` sent immediately afterwards
can race ahead of the subscription and be lost. Polling
``pg_listening_channels()`` on the listener's own connection removes that
race deterministically.
"""
wanted = set(channels)
with anyio.fail_after(WAIT_TIMEOUT):
while True:
result = await listener.connection.execute(
"SELECT pg_listening_channels() AS channel",
)
active = {row["channel"] for row in result.result()}
if wanted <= active:
return
await asyncio.sleep(POLL_INTERVAL)


async def notify(
psql_pool: ConnectionPool,
channel: str = TEST_CHANNEL,
with_delay: bool = False,
) -> None:
if with_delay:
await asyncio.sleep(0.5)

connection = await psql_pool.connection()
await connection.execute(f"NOTIFY {channel}, '{TEST_PAYLOAD}'")
connection.close()
try:
await connection.execute(f"NOTIFY {channel}, '{TEST_PAYLOAD}'")
finally:
connection.close()


async def check_insert_callback(
async def wait_for_callback(
psql_pool: ConnectionPool,
listener_table_name: str,
is_insert_exist: bool = True,
number_of_data: int = 1,
) -> None:
connection = await psql_pool.connection()
test_data_seq = (
await connection.execute(
f"SELECT * FROM {listener_table_name}",
)
).result()
) -> list[dict[str, typing.Any]]:
"""Poll the result table until the callback has inserted ``number_of_data`` rows."""
with anyio.fail_after(WAIT_TIMEOUT):
while True:
rows = await read_test_table(psql_pool, listener_table_name)
if len(rows) >= number_of_data:
return rows
await asyncio.sleep(POLL_INTERVAL)

if is_insert_exist:
assert len(test_data_seq) == number_of_data
else:
assert not len(test_data_seq)
return

data_record = test_data_seq[0]
async def read_test_table(
psql_pool: ConnectionPool,
listener_table_name: str,
) -> list[dict[str, typing.Any]]:
connection = await psql_pool.connection()
try:
return (
await connection.execute(
f"SELECT * FROM {listener_table_name}",
)
).result()
finally:
connection.close()

assert data_record["payload"] == TEST_PAYLOAD
assert data_record["channel"] == TEST_CHANNEL

connection.close()
async def assert_no_callback(
psql_pool: ConnectionPool,
listener_table_name: str,
settle: float = SETTLE_TIMEOUT,
) -> None:
"""Give a notification time to (not) be delivered, then assert nothing landed."""
await asyncio.sleep(settle)
rows = await read_test_table(psql_pool, listener_table_name)
assert not rows


async def clear_test_table(
Expand All @@ -118,13 +163,15 @@ async def test_listener_listen(
await listener.startup()
listener.listen()

await wait_until_listening(listener, TEST_CHANNEL)
await notify(psql_pool=psql_pool)
await asyncio.sleep(0.5)

await check_insert_callback(
rows = await wait_for_callback(
psql_pool=psql_pool,
listener_table_name=listener_table_name,
)
assert rows[0]["payload"] == TEST_PAYLOAD
assert rows[0]["channel"] == TEST_CHANNEL

await listener.shutdown()

Expand All @@ -140,17 +187,21 @@ async def test_listener_asynciterator(
)
await listener.startup()

asyncio.create_task( # noqa: RUF006
notify(
psql_pool=psql_pool,
with_delay=True,
),
)

async for listener_msg in listener:
assert listener_msg.channel == TEST_CHANNEL
assert listener_msg.payload == TEST_PAYLOAD
break
async def trigger() -> None:
# Iteration subscribes lazily inside ``__anext__``; wait for the
# subscription before notifying so the message can't be lost.
await wait_until_listening(listener, TEST_CHANNEL)
await notify(psql_pool=psql_pool)

# Bound the iteration: if the notification is never delivered this fails
# fast instead of blocking the session forever.
with anyio.fail_after(WAIT_TIMEOUT):
async with anyio.create_task_group() as task_group:
task_group.start_soon(trigger)
async for listener_msg in listener:
assert listener_msg.channel == TEST_CHANNEL
assert listener_msg.payload == TEST_PAYLOAD
break

await listener.shutdown()

Expand All @@ -167,10 +218,10 @@ async def test_listener_abort(
await listener.startup()
listener.listen()

await wait_until_listening(listener, TEST_CHANNEL)
await notify(psql_pool=psql_pool)
await asyncio.sleep(0.5)

await check_insert_callback(
await wait_for_callback(
psql_pool=psql_pool,
listener_table_name=listener_table_name,
)
Expand All @@ -183,12 +234,10 @@ async def test_listener_abort(
)

await notify(psql_pool=psql_pool)
await asyncio.sleep(0.5)

await check_insert_callback(
await assert_no_callback(
psql_pool=psql_pool,
listener_table_name=listener_table_name,
is_insert_exist=False,
)


Expand All @@ -215,8 +264,11 @@ async def test_listener_double_start_exc(
)
await listener.startup()

with pytest.raises(expected_exception=ListenerStartError):
await listener.startup()
try:
with pytest.raises(expected_exception=ListenerStartError):
await listener.startup()
finally:
await listener.shutdown()


@pytest.mark.usefixtures("create_table_for_listener_tests")
Expand All @@ -238,14 +290,14 @@ async def test_listener_more_than_one_callback(
await listener.startup()
listener.listen()

await wait_until_listening(listener, TEST_CHANNEL, additional_channel)
for channel in [TEST_CHANNEL, additional_channel]:
await notify(
psql_pool=psql_pool,
channel=channel,
)

await asyncio.sleep(0.5)
await check_insert_callback(
await wait_for_callback(
psql_pool=psql_pool,
listener_table_name=listener_table_name,
number_of_data=2,
Expand Down Expand Up @@ -282,12 +334,10 @@ async def test_listener_clear_callbacks(
)

await notify(psql_pool=psql_pool)
await asyncio.sleep(0.5)

await check_insert_callback(
await assert_no_callback(
psql_pool=psql_pool,
listener_table_name=listener_table_name,
is_insert_exist=False,
)

await listener.shutdown()
Expand All @@ -309,12 +359,10 @@ async def test_listener_clear_all_callbacks(
await listener.clear_all_channels()

await notify(psql_pool=psql_pool)
await asyncio.sleep(0.5)

await check_insert_callback(
await assert_no_callback(
psql_pool=psql_pool,
listener_table_name=listener_table_name,
is_insert_exist=False,
)

await listener.shutdown()
2 changes: 1 addition & 1 deletion src/driver/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ macro_rules! impl_copy_records_method {
/// pass Python values directly (the same conversions used by
/// `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`.
///
/// The encoder follows asyncpg's algorithm: a single BytesMut
/// The encoder follows asyncpg's algorithm: a single `BytesMut`
/// accumulator flushed into 512 KiB (`_COPY_BUFFER_SIZE`) chunks.
/// All rows are encoded during the GIL pass; chunks are sent to the
/// server in a second pass after the GIL is released.
Expand Down
Loading
Loading