From 3b499ce0075e411f298e795e787789e5a6e755bf Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 27 Jun 2026 21:07:21 +0200 Subject: [PATCH 1/2] Make the listener reconcile LISTEN/UNLISTEN without a lock-order deadlock, end the message-forwarding task cleanly instead of panicking, and bound the listener tests with deterministic waits plus a pytest-timeout backstop. --- pyproject.toml | 11 ++- python/tests/test_listener.py | 152 ++++++++++++++++++++++------------ src/driver/listener/core.rs | 132 +++++++++++++++++++---------- tox.ini | 2 + 4 files changed, 201 insertions(+), 96 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 22351032..e1e0bef7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,9 @@ lint = [ "ruff" ] test = [ - "pytest" + "pytest", + "pytest-timeout", + "anyio", ] [tool.maturin] @@ -61,6 +63,13 @@ homepage = "https://github.com/psqlpy-python/psqlpy" repository = "https://github.com/psqlpy-python/psqlpy" documentation = "https://psqlpy-python.github.io/" +[tool.pytest.ini_options] +# Safety net so a wedged test (e.g. a listener waiting on a notification that +# never arrives) fails fast instead of hanging the whole CI job. The listener +# tests bound their own waits well below this; this is the backstop for anything +# unforeseen. +timeout = 120 + [tool.mypy] strict = true mypy_path = "python" diff --git a/python/tests/test_listener.py b/python/tests/test_listener.py index 8db12ae9..5fcf2c0e 100644 --- a/python/tests/test_listener.py +++ b/python/tests/test_listener.py @@ -3,6 +3,7 @@ import asyncio import typing +import anyio import pytest from psqlpy.exceptions import ListenerStartError @@ -15,6 +16,15 @@ TEST_CHANNEL = "test_channel" TEST_PAYLOAD = "test_payload" +# How long helpers wait for an asynchronous condition before giving up. +# These bound every wait in this module so a lost/late NOTIFY surfaces as a +# fast, explicit failure instead of hanging the whole test session forever +# (which is what used to wedge the GitHub runners). +WAIT_TIMEOUT = 10.0 +POLL_INTERVAL = 0.05 +# Time we allow a notification to be (not) delivered before asserting absence. +SETTLE_TIMEOUT = 1.0 + async def construct_listener( psql_pool: ConnectionPool, @@ -54,44 +64,79 @@ async def callback( return callback +async def wait_until_listening( + listener: Listener, + *channels: str, +) -> None: + """Block until the listener's backend session is subscribed to ``channels``. + + ``Listener.listen()`` and async iteration issue the ``LISTEN`` statements + lazily from a background task, so a ``NOTIFY`` sent immediately afterwards + can race ahead of the subscription and be lost. Polling + ``pg_listening_channels()`` on the listener's own connection removes that + race deterministically. + """ + wanted = set(channels) + with anyio.fail_after(WAIT_TIMEOUT): + while True: + result = await listener.connection.execute( + "SELECT pg_listening_channels() AS channel", + ) + active = {row["channel"] for row in result.result()} + if wanted <= active: + return + await asyncio.sleep(POLL_INTERVAL) + + async def notify( psql_pool: ConnectionPool, channel: str = TEST_CHANNEL, - with_delay: bool = False, ) -> None: - if with_delay: - await asyncio.sleep(0.5) - connection = await psql_pool.connection() - await connection.execute(f"NOTIFY {channel}, '{TEST_PAYLOAD}'") - connection.close() + try: + await connection.execute(f"NOTIFY {channel}, '{TEST_PAYLOAD}'") + finally: + connection.close() -async def check_insert_callback( +async def wait_for_callback( psql_pool: ConnectionPool, listener_table_name: str, - is_insert_exist: bool = True, number_of_data: int = 1, -) -> None: - connection = await psql_pool.connection() - test_data_seq = ( - await connection.execute( - f"SELECT * FROM {listener_table_name}", - ) - ).result() +) -> list[dict[str, typing.Any]]: + """Poll the result table until the callback has inserted ``number_of_data`` rows.""" + with anyio.fail_after(WAIT_TIMEOUT): + while True: + rows = await read_test_table(psql_pool, listener_table_name) + if len(rows) >= number_of_data: + return rows + await asyncio.sleep(POLL_INTERVAL) - if is_insert_exist: - assert len(test_data_seq) == number_of_data - else: - assert not len(test_data_seq) - return - data_record = test_data_seq[0] +async def read_test_table( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> list[dict[str, typing.Any]]: + connection = await psql_pool.connection() + try: + return ( + await connection.execute( + f"SELECT * FROM {listener_table_name}", + ) + ).result() + finally: + connection.close() - assert data_record["payload"] == TEST_PAYLOAD - assert data_record["channel"] == TEST_CHANNEL - connection.close() +async def assert_no_callback( + psql_pool: ConnectionPool, + listener_table_name: str, + settle: float = SETTLE_TIMEOUT, +) -> None: + """Give a notification time to (not) be delivered, then assert nothing landed.""" + await asyncio.sleep(settle) + rows = await read_test_table(psql_pool, listener_table_name) + assert not rows async def clear_test_table( @@ -118,13 +163,15 @@ async def test_listener_listen( await listener.startup() listener.listen() + await wait_until_listening(listener, TEST_CHANNEL) await notify(psql_pool=psql_pool) - await asyncio.sleep(0.5) - await check_insert_callback( + rows = await wait_for_callback( psql_pool=psql_pool, listener_table_name=listener_table_name, ) + assert rows[0]["payload"] == TEST_PAYLOAD + assert rows[0]["channel"] == TEST_CHANNEL await listener.shutdown() @@ -140,17 +187,21 @@ async def test_listener_asynciterator( ) await listener.startup() - asyncio.create_task( # noqa: RUF006 - notify( - psql_pool=psql_pool, - with_delay=True, - ), - ) - - async for listener_msg in listener: - assert listener_msg.channel == TEST_CHANNEL - assert listener_msg.payload == TEST_PAYLOAD - break + async def trigger() -> None: + # Iteration subscribes lazily inside ``__anext__``; wait for the + # subscription before notifying so the message can't be lost. + await wait_until_listening(listener, TEST_CHANNEL) + await notify(psql_pool=psql_pool) + + # Bound the iteration: if the notification is never delivered this fails + # fast instead of blocking the session forever. + with anyio.fail_after(WAIT_TIMEOUT): + async with anyio.create_task_group() as task_group: + task_group.start_soon(trigger) + async for listener_msg in listener: + assert listener_msg.channel == TEST_CHANNEL + assert listener_msg.payload == TEST_PAYLOAD + break await listener.shutdown() @@ -167,10 +218,10 @@ async def test_listener_abort( await listener.startup() listener.listen() + await wait_until_listening(listener, TEST_CHANNEL) await notify(psql_pool=psql_pool) - await asyncio.sleep(0.5) - await check_insert_callback( + await wait_for_callback( psql_pool=psql_pool, listener_table_name=listener_table_name, ) @@ -183,12 +234,10 @@ async def test_listener_abort( ) await notify(psql_pool=psql_pool) - await asyncio.sleep(0.5) - await check_insert_callback( + await assert_no_callback( psql_pool=psql_pool, listener_table_name=listener_table_name, - is_insert_exist=False, ) @@ -215,8 +264,11 @@ async def test_listener_double_start_exc( ) await listener.startup() - with pytest.raises(expected_exception=ListenerStartError): - await listener.startup() + try: + with pytest.raises(expected_exception=ListenerStartError): + await listener.startup() + finally: + await listener.shutdown() @pytest.mark.usefixtures("create_table_for_listener_tests") @@ -238,14 +290,14 @@ async def test_listener_more_than_one_callback( await listener.startup() listener.listen() + await wait_until_listening(listener, TEST_CHANNEL, additional_channel) for channel in [TEST_CHANNEL, additional_channel]: await notify( psql_pool=psql_pool, channel=channel, ) - await asyncio.sleep(0.5) - await check_insert_callback( + await wait_for_callback( psql_pool=psql_pool, listener_table_name=listener_table_name, number_of_data=2, @@ -282,12 +334,10 @@ async def test_listener_clear_callbacks( ) await notify(psql_pool=psql_pool) - await asyncio.sleep(0.5) - await check_insert_callback( + await assert_no_callback( psql_pool=psql_pool, listener_table_name=listener_table_name, - is_insert_exist=False, ) await listener.shutdown() @@ -309,12 +359,10 @@ async def test_listener_clear_all_callbacks( await listener.clear_all_channels() await notify(psql_pool=psql_pool) - await asyncio.sleep(0.5) - await check_insert_callback( + await assert_no_callback( psql_pool=psql_pool, listener_table_name=listener_table_name, - is_insert_exist=False, ) await listener.shutdown() diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index 7d37679f..28fcbef2 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -1,6 +1,6 @@ -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; -use futures::{stream, FutureExt, StreamExt, TryStreamExt}; +use futures::{pin_mut, stream, StreamExt}; use futures_channel::mpsc::UnboundedReceiver; use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; use postgres_openssl::MakeTlsConnector; @@ -38,7 +38,9 @@ pub struct Listener { listen_abort_handler: Option, connection: Connection, receiver: Option>>>, - listen_query: Arc>, + /// Channels currently subscribed on the backend session. Diffed against the + /// desired set (`channel_callbacks`) to reconcile LISTEN/UNLISTEN. + applied_channels: Arc>>, is_listened: Arc>, is_started: bool, } @@ -58,29 +60,20 @@ impl Listener { listen_abort_handler: Option::default(), connection: Connection::new(None, None, pg_config.clone()), receiver: Option::default(), - listen_query: Arc::default(), + applied_channels: Arc::default(), is_listened: Arc::new(RwLock::new(false)), is_started: false, } } - async fn update_listen_query(&self) { - let read_channel_callbacks = self.channel_callbacks.read().await; - - let channels = read_channel_callbacks.retrieve_all_channels(); - - let mut final_query: String = String::default(); - - for channel_name in channels { - final_query.push_str(format!("LISTEN {channel_name};").as_str()); - } - - let mut write_listen_query = self.listen_query.write().await; - let mut write_is_listened = self.is_listened.write().await; - - write_listen_query.clear(); - write_listen_query.push_str(&final_query); - *write_is_listened = false; + /// Flag that the backend subscriptions no longer match the desired channel + /// set, so the next `execute_listen` reconciles them (LISTEN/UNLISTEN). + /// + /// Only `is_listened` is taken here, never while holding `channel_callbacks`, + /// so there is no lock-order cycle with `execute_listen` (which takes + /// `is_listened` first, then reads `channel_callbacks`). + async fn mark_subscriptions_dirty(&self) { + *self.is_listened.write().await = false; } } @@ -146,13 +139,20 @@ impl Listener { }; let is_listened_clone = self.is_listened.clone(); - let listen_query_clone = self.listen_query.clone(); + let channel_callbacks_clone = self.channel_callbacks.clone(); + let applied_channels_clone = self.applied_channels.clone(); let connection = self.connection.clone(); let py_future = Python::with_gil(move |gil| { rustdriver_future(gil, async move { { - execute_listen(&is_listened_clone, &listen_query_clone, &client).await?; + execute_listen( + &is_listened_clone, + &channel_callbacks_clone, + &applied_channels_clone, + &client, + ) + .await?; }; let next_element = { let mut write_receiver = receiver.write().await; @@ -214,15 +214,25 @@ impl Listener { let (transmitter, receiver) = futures_channel::mpsc::unbounded::(); - let stream = - stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e)); - - let connection = stream.forward(transmitter).map(|r| { - r.map_err(|_| { - RustPSQLDriverError::ListenerStartError("Cannot startup the listener".into()) - }) - }); - tokio_runtime().spawn(connection); + let forward_messages = async move { + let stream = stream::poll_fn(move |cx| connection.poll_message(cx)); + pin_mut!(stream); + + while let Some(message) = stream.next().await { + match message { + // Receiver gone (listener shut down) -> stop forwarding. + Ok(async_message) => { + if transmitter.unbounded_send(async_message).is_err() { + break; + } + } + // Connection closed or errored -> end the task cleanly + // instead of panicking the worker thread. + Err(_) => break, + } + } + }; + tokio_runtime().spawn(forward_messages); self.receiver = Some(Arc::new(RwLock::new(receiver))); self.connection = Connection::new( @@ -263,7 +273,7 @@ impl Listener { write_channel_callbacks.add_callback(channel, listener_callback); } - self.update_listen_query().await; + self.mark_subscriptions_dirty().await; Ok(()) } @@ -274,7 +284,7 @@ impl Listener { write_channel_callbacks.clear_channel_callbacks(&channel); } - self.update_listen_query().await; + self.mark_subscriptions_dirty().await; } async fn clear_all_channels(&mut self) { @@ -283,7 +293,7 @@ impl Listener { write_channel_callbacks.clear_all(); } - self.update_listen_query().await; + self.mark_subscriptions_dirty().await; } fn listen(&mut self) -> PSQLPyResult<()> { @@ -299,15 +309,21 @@ impl Listener { }; let connection = self.connection.clone(); - let listen_query_clone = self.listen_query.clone(); let is_listened_clone = self.is_listened.clone(); + let applied_channels = self.applied_channels.clone(); let channel_callbacks = self.channel_callbacks.clone(); let jh: JoinHandle> = tokio_runtime().spawn(async move { loop { { - execute_listen(&is_listened_clone, &listen_query_clone, &client).await?; + execute_listen( + &is_listened_clone, + &channel_callbacks, + &applied_channels, + &client, + ) + .await?; }; let next_element = { @@ -358,21 +374,51 @@ async fn dispatch_callback( Ok(()) } +/// Reconcile the backend subscriptions with the desired channel set. +/// +/// When `is_listened` is dirty (`false`) the difference between the desired +/// channels (`channel_callbacks`) and the channels already applied on the +/// backend (`applied_channels`) is turned into `UNLISTEN`/`LISTEN` statements +/// and executed. Re-subscribing is idempotent, so a redundant `LISTEN` is +/// harmless; the `UNLISTEN` half is what stops a cleared channel from delivering. +/// +/// Lock order is `client` -> `is_listened` -> `channel_callbacks` -> +/// `applied_channels`. `mark_subscriptions_dirty` only ever takes `is_listened` +/// (never while holding `channel_callbacks`), so the two cannot deadlock. async fn execute_listen( is_listened: &Arc>, - listen_query: &Arc>, + channel_callbacks: &Arc>, + applied_channels: &Arc>>, client: &Arc>, ) -> PSQLPyResult<()> { let read_conn_g = client.read().await; let mut write_is_listened = is_listened.write().await; - if !write_is_listened.eq(&true) { - let listen_q = { - let read_listen_query = listen_query.read().await; - String::from(read_listen_query.as_str()) + if !*write_is_listened { + let desired: HashSet = { + let read_channel_callbacks = channel_callbacks.read().await; + read_channel_callbacks + .retrieve_all_channels() + .into_iter() + .cloned() + .collect() }; - read_conn_g.batch_execute(listen_q.as_str()).await?; + let mut applied = applied_channels.write().await; + + let mut reconcile_query = String::new(); + for channel in applied.difference(&desired) { + reconcile_query.push_str(format!("UNLISTEN {channel};").as_str()); + } + for channel in desired.difference(&applied) { + reconcile_query.push_str(format!("LISTEN {channel};").as_str()); + } + + if !reconcile_query.is_empty() { + read_conn_g.batch_execute(reconcile_query.as_str()).await?; + } + + *applied = desired; } *write_is_listened = true; diff --git a/tox.ini b/tox.ini index 014909df..2127378a 100644 --- a/tox.ini +++ b/tox.ini @@ -19,6 +19,7 @@ python = skip_install = true deps = pytest>=7,<8 + pytest-timeout>=2,<3 anyio>=3,<4 maturin>=1,<2 pydantic>=2 @@ -34,6 +35,7 @@ commands = skip_install = true deps = pytest>=7,<8 + pytest-timeout>=2,<3 anyio>=3,<4 maturin>=1,<2 pydantic>=2 From 4fe296195b4f87dff4c490e07a1bde5ea673391c Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 28 Jun 2026 18:04:21 +0200 Subject: [PATCH 2/2] Fix clippy doc_markdown: backtick BytesMut in copy_records doc comment Co-Authored-By: Claude Opus 4.8 (1M context) --- src/driver/common.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/driver/common.rs b/src/driver/common.rs index b78d66af..b6d26946 100644 --- a/src/driver/common.rs +++ b/src/driver/common.rs @@ -371,7 +371,7 @@ macro_rules! impl_copy_records_method { /// pass Python values directly (the same conversions used by /// `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`. /// - /// The encoder follows asyncpg's algorithm: a single BytesMut + /// The encoder follows asyncpg's algorithm: a single `BytesMut` /// accumulator flushed into 512 KiB (`_COPY_BUFFER_SIZE`) chunks. /// All rows are encoded during the GIL pass; chunks are sent to the /// server in a second pass after the GIL is released.